{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License.\n",
    "\n",
    "# UNet++ input size constraints\n",
    "\n",
    "MONAI provides an enhanced version of UNet (``monai.networks.nets.UNet ``), which not only supports residual units, but also can use more hyperparameters (like ``strides``, ``kernel_size`` and ``up_kernel_size``) than ``monai.networks.nets.BasicUNet``. However, ``UNet`` has some constraints for both network hyperparameters and sizes of input.\n",
    "\n",
    "MONAI provides a version of UNET++ (`` monai.networks.nets.BasicUnetPlusPlus ``), with fixed num. of down-scale layer, strides of 2. The configurations you can change are: the number input and output channels, number of hidden channels (6 different layers), norm and activation, bias of convolution, dropout rate, and up-sampling model. As `UNET`, different model configurations can affect the input shape.\n",
    "\n",
    "The constraints of hyper-parameters can be found in the docstring of the network, and this tutorial is focused on how to determine a reasonable input size."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q monai-weekly"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MONAI version: 0+untagged.2891.gccd32ca\n",
      "Numpy version: 1.25.1\n",
      "Pytorch version: 2.0.1\n",
      "MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n",
      "MONAI rev id: ccd32ca5e9e84562d2f388b45b6724b5c77c1f57\n",
      "MONAI __file__: /Users/<username>/Envs/monai/lib/python3.9/site-packages/monai/__init__.py\n",
      "\n",
      "Optional dependencies:\n",
      "Pytorch Ignite version: 0.4.11\n",
      "ITK version: 5.3.0\n",
      "Nibabel version: 5.1.0\n",
      "scikit-image version: 0.21.0\n",
      "scipy version: 1.11.1\n",
      "Pillow version: 10.0.0\n",
      "Tensorboard version: 2.13.0\n",
      "gdown version: 4.7.1\n",
      "TorchVision version: 0.15.2\n",
      "tqdm version: 4.65.0\n",
      "lmdb version: 1.4.1\n",
      "psutil version: 5.9.5\n",
      "pandas version: 2.0.3\n",
      "einops version: 0.6.1\n",
      "transformers version: 4.21.3\n",
      "mlflow version: 2.4.2\n",
      "pynrrd version: 1.0.0\n",
      "clearml version: 1.11.2rc0\n",
      "\n",
      "For details about installing the optional dependencies, please visit:\n",
      "    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from monai.networks.nets import BasicUnetPlusPlus\n",
    "import monai\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from typing import Dict\n",
    "\n",
    "monai.config.print_config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check UNet++ structure\n",
    "\n",
    "![](../../figures/unet++.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = BasicUnetPlusPlus(\n",
    "    spatial_dims=3,\n",
    "    in_channels=3,\n",
    "    out_channels=3,\n",
    "    features=(32, 32, 64, 128, 256, 32),\n",
    "    # norm='localresponse',\n",
    "    norm=\"batch\",\n",
    ")\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Normalization\n",
    "\n",
    "UNET++ use the same `TwoConv`, `Down`, and `UpCat` as UNet. Therefore, you can referred to the `modules/UNet_input_size_constraints.ipynb` for break down analysis. For summary, the constraints for these types of normalization are:\n",
    "\n",
    "- Instance Norm: the product of spatial dimension must > 1 (not include channel and batch)\n",
    "- Batch Norm: the product of spatial dimension and batch must > 1 (not include channels). For training best interested, `batch_size` should be larger than 1\n",
    "- Local Response Norm: No constraint.\n",
    "- Other Normalization: please referred to `modules/UNet_input_size_constraints.ipynb`\n",
    "\n",
    "As for UNET++ have 4 down-sampling blocks with 2x kernel size, with no argument to change this behavior, the smallest edge we can have is `2**4 = 16`, and after the last down-sampling block, the `vector.shape  = [..., ..., 1, 1]` or (`[..., ..., 1, 1, 1]` for 3D), which will cause error for the Normalization layer.\n",
    "\n",
    "See the test code below for examples of batch norm and instance norm\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('INSTANCE', 'BATCH', 'INSTANCE_NVFUSER', 'GROUP', 'LAYER', 'LOCALRESPONSE', 'SYNCBATCH')\n"
     ]
    }
   ],
   "source": [
    "print(monai.networks.layers.factories.Norm.names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prepare model\n",
      "BasicUNetPlusPlus features: (32, 32, 64, 128, 256, 32).\n",
      "BasicUNetPlusPlus features: (32, 32, 64, 128, 256, 32).\n",
      "========== USING NORM LAYER: INSTANCE ==========\n",
      "__________ Changing the H dimension of 2D input __________\n",
      ">> Using Input.shape=torch.Size([1, 3, 16, 16])\n",
      ">>>> Exception: Expected more than 1 spatial element when training, got input size torch.Size([1, 256, 1, 1])\n",
      "\n",
      ">> Using Input.shape=torch.Size([1, 3, 32, 16])\n",
      "__________ Changing the batch size __________\n",
      ">> Input.shape=torch.Size([1, 3, 16, 16])\n",
      ">> Exception: Expected more than 1 spatial element when training, got input size torch.Size([1, 256, 1, 1])\n",
      "\n",
      ">> Input.shape=torch.Size([2, 3, 16, 16])\n",
      ">> Exception: Expected more than 1 spatial element when training, got input size torch.Size([2, 256, 1, 1])\n",
      "\n",
      "========== USING NORM LAYER: BATCH ==========\n",
      "__________ Changing the H dimension of 2D input __________\n",
      ">> Using Input.shape=torch.Size([1, 3, 16, 16])\n",
      ">>>> Exception: Expected more than 1 value per channel when training, got input size torch.Size([1, 256, 1, 1])\n",
      "\n",
      ">> Using Input.shape=torch.Size([1, 3, 32, 16])\n",
      "__________ Changing the batch size __________\n",
      ">> Input.shape=torch.Size([1, 3, 16, 16])\n",
      ">> Exception: Expected more than 1 value per channel when training, got input size torch.Size([1, 256, 1, 1])\n",
      "\n",
      ">> Input.shape=torch.Size([2, 3, 16, 16])\n"
     ]
    }
   ],
   "source": [
    "def make_model_with_layer(layer_norm):\n",
    "    return BasicUnetPlusPlus(\n",
    "        spatial_dims=2, in_channels=3, out_channels=1, features=(32, 32, 64, 128, 256, 32), norm=layer_norm\n",
    "    )\n",
    "\n",
    "\n",
    "def test_min_dim():\n",
    "    min_edge = 16\n",
    "    batch_size, spatial_dim, height, width = 1, 3, min_edge, min_edge\n",
    "    model_dict: Dict[str, BasicUnetPlusPlus] = {}\n",
    "    print(\"Prepare model\")\n",
    "    for norm_layer in [\"instance\", \"batch\"]:\n",
    "        model_dict[norm_layer] = make_model_with_layer(norm_layer)\n",
    "\n",
    "    # print(f\"Input dimension {(batch_size, spatial_dim, H, W)} that will cause error\")\n",
    "    for norm_layer in [\"instance\", \"batch\"]:\n",
    "        print(\"=\" * 10 + f\" USING NORM LAYER: {norm_layer.upper()} \" + \"=\" * 10)\n",
    "        model = model_dict[norm_layer]\n",
    "        print(\"_\" * 10 + \" Changing the H dimension of 2D input \" + \"_\" * 10)\n",
    "        for temp_height in [height, height * 2]:\n",
    "            try:\n",
    "                x = torch.ones(batch_size, spatial_dim, temp_height, width)\n",
    "                print(f\">> Using Input.shape={x.shape}\")\n",
    "                model(x)\n",
    "            except Exception as msg:\n",
    "                print(f\">>>> Exception: {msg}\\n\")\n",
    "\n",
    "        # print(\"Changing the batch size\")\n",
    "        print(\"_\" * 10 + \" Changing the batch size \" + \"_\" * 10)\n",
    "        for batch_size_tmp in [1, 2]:\n",
    "            try:\n",
    "                x = torch.ones(batch_size_tmp, spatial_dim, height, width)\n",
    "                print(f\">> Input.shape={x.shape}\")\n",
    "                model(x)\n",
    "            except Exception as msg:\n",
    "                print(f\">> Exception: {msg}\\n\")\n",
    "\n",
    "\n",
    "with torch.no_grad():\n",
    "    test_min_dim()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Normalization conclusion\n",
    "\n",
    "**Note:** These are lower constraint. A higher resolution input is recommended.\n",
    "\n",
    "For convention, let consider the shape of 3D model is `(B, C, D, H, W)` and for 2D model is `(B, C, H, W)`. The minimum value of any `D, H, W` is `16`, as mentioned above\n",
    "\n",
    "If you are using:\n",
    "- Batch Norm: ensure that, there are **at least one value** of `(D, H, W)` for 3D, or `(H, W)`:\n",
    "  - `>= 32` if the batch size `== 1`\n",
    "  - `>= 16` if the batch size `> 1`\n",
    "\n",
    "- Instance Norm: ensure that there are **at least one value** of `(D, H, W)` for 3D, or `(H, W)` that `>= 32`\n",
    "\n",
    "- For Local response: No constraint.\n",
    "\n",
    "- Others normalization: those norm required input shape agree with the norm's parameters, therefore you will have to research about those layer before any usage.\n",
    "\n",
    "\n",
    "**Note**: also note that, you can pass argument to normalization layer (some will result in error if you don't), check below example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BasicUNetPlusPlus features: (32, 32, 64, 128, 256, 32).\n"
     ]
    }
   ],
   "source": [
    "# NOTE: this will result in error, as lack of argument\n",
    "# model = BasicUnetPlusPlus(\n",
    "#     spatial_dims=3,\n",
    "#     in_channels=3,\n",
    "#     out_channels=3,\n",
    "#     features=(32, 32, 64, 128, 256, 32),\n",
    "#     norm='localresponse',\n",
    "# )\n",
    "\n",
    "# NOTE: this will work fine\n",
    "model = BasicUnetPlusPlus(\n",
    "    spatial_dims=3,\n",
    "    in_channels=3,\n",
    "    out_channels=3,\n",
    "    features=(32, 32, 64, 128, 256, 32),\n",
    "    norm=(\"localresponse\", {\"size\": 10}),\n",
    ")"
   ]
  },
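  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The normalization rules above can be condensed into a small check. The sketch below is not part of MONAI: `bottleneck_shape` and `satisfies_norm_constraint` are hypothetical helpers that assume the 4 fixed 2x poolings described earlier and only cover instance norm and batch norm in training mode.\n",
    "\n",
    "```python\n",
    "from math import prod\n",
    "\n",
    "NUM_POOLINGS = 4  # assumption: fixed in BasicUnetPlusPlus, as described above\n",
    "\n",
    "\n",
    "def bottleneck_shape(spatial_shape):\n",
    "    '''Spatial shape after four 2x max-poolings (each one floors the halved size).'''\n",
    "    return tuple(size // 2**NUM_POOLINGS for size in spatial_shape)\n",
    "\n",
    "\n",
    "def satisfies_norm_constraint(spatial_shape, batch_size=1, norm='batch'):\n",
    "    '''Hypothetical helper: apply the rules above to the deepest feature map.'''\n",
    "    deepest = bottleneck_shape(spatial_shape)\n",
    "    if min(deepest) < 1:\n",
    "        return False  # an edge shorter than 16 vanishes before the last level\n",
    "    if norm == 'instance':\n",
    "        return prod(deepest) > 1\n",
    "    if norm == 'batch':\n",
    "        return batch_size * prod(deepest) > 1\n",
    "    raise ValueError(f'no rule sketched for norm={norm!r}')\n",
    "\n",
    "\n",
    "print(satisfies_norm_constraint((16, 16), batch_size=1, norm='instance'))  # False\n",
    "print(satisfies_norm_constraint((32, 16), batch_size=1, norm='instance'))  # True\n",
    "print(satisfies_norm_constraint((16, 16), batch_size=2, norm='batch'))  # True\n",
    "```\n",
    "\n",
    "The three printed cases mirror the test output above: a `16 x 16` input fails with instance norm, doubling one edge fixes it, and batch norm accepts `16 x 16` once the batch size is 2."
   ]
  },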
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pooling (down-sampling) and Up-sampling\n",
    "\n",
    "While all the internal padding is handled to ensure the input spatial shape is the same as the output spatial shape, you still need to aware about these implicit processing:\n",
    "\n",
    "- For down-scale path: MaxPooling with `kernel_size` and `strides` = 2. Any odd size will have the last element ignored. See below code block for the pooling with odd shape\n",
    "- For up-scale path, depend on the up-samping mode, using `UpCat` module with auto padding (`replicate` mode) the up-sampled path to the same input shape as input path. See this code of the `UpCat`\n",
    "\n",
    "```[python]\n",
    "def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]):\n",
    "        \"\"\"\n",
    "\n",
    "        Args:\n",
    "            x: features to be upsampled.\n",
    "            x_e: optional features from the encoder, if None, this branch is not in use.\n",
    "        \"\"\"\n",
    "        x_0 = self.upsample(x)\n",
    "\n",
    "        if x_e is not None and torch.jit.isinstance(x_e, torch.Tensor):\n",
    "            if self.is_pad: # alway True\n",
    "                # handling spatial shapes due to the 2x max-pooling with odd edge lengths.\n",
    "                dimensions = len(x.shape) - 2\n",
    "                sp = [0] * (dimensions * 2)\n",
    "                for i in range(dimensions):\n",
    "                    if x_e.shape[-i - 1] != x_0.shape[-i - 1]:\n",
    "                        sp[i * 2 + 1] = 1\n",
    "                x_0 = torch.nn.functional.pad(x_0, sp, \"replicate\")\n",
    "            x = self.convs(torch.cat([x_e, x_0], dim=1))  # input channels: (cat_chns + up_chns)\n",
    "        else:\n",
    "            x = self.convs(x_0)\n",
    "\n",
    "        return x\n",
    "\n",
    "```\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "un_pool.shape=torch.Size([1, 1, 5, 5]), pooled.shape=torch.Size([1, 1, 2, 2])\n",
      "tensor([[[[ 1.,  2.,  3.,  4.,  5.],\n",
      "          [ 2.,  4.,  6.,  8., 10.],\n",
      "          [ 3.,  6.,  9., 12., 15.],\n",
      "          [ 4.,  8., 12., 16., 20.],\n",
      "          [ 5., 10., 15., 20., 25.]]]])\n",
      "tensor([[[[ 4.,  8.],\n",
      "          [ 8., 16.]]]])\n"
     ]
    }
   ],
   "source": [
    "# Example about the pooling with odd shape\n",
    "un_pool = (\n",
    "    torch.Tensor(\n",
    "        [\n",
    "            [1, 2, 3, 4, 5],\n",
    "            [1, 2, 3, 4, 5],\n",
    "            [1, 2, 3, 4, 5],\n",
    "            [1, 2, 3, 4, 5],\n",
    "            [1, 2, 3, 4, 5],\n",
    "        ]\n",
    "    )\n",
    "    * torch.Tensor([1, 2, 3, 4, 5])[..., None]\n",
    ")\n",
    "un_pool = un_pool[None, None, ...]\n",
    "\n",
    "pooled = nn.MaxPool2d(kernel_size=2)(un_pool)\n",
    "print(f\"{un_pool.shape=}, {pooled.shape=}\")\n",
    "print(un_pool)\n",
    "print(pooled)"
   ]
  },
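  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see where `UpCat` has to pad, it helps to trace one edge length through the network. The sketch below uses hypothetical helpers (`encoder_sizes`, `upcat_padding`) that are not part of MONAI. It assumes that each pooling floors the halved size, as in the example above, and that each up-sampling step exactly doubles the edge length.\n",
    "\n",
    "```python\n",
    "def encoder_sizes(size, num_poolings=4):\n",
    "    '''Edge length at each encoder level; 2x max-pooling floors odd sizes.'''\n",
    "    sizes = [size]\n",
    "    for _ in range(num_poolings):\n",
    "        sizes.append(sizes[-1] // 2)\n",
    "    return sizes\n",
    "\n",
    "\n",
    "def upcat_padding(size, num_poolings=4):\n",
    "    '''Padding (0 or 1) needed after each up-sampling, from deepest to shallowest.'''\n",
    "    sizes = encoder_sizes(size, num_poolings)\n",
    "    pads = []\n",
    "    for level in range(num_poolings, 0, -1):\n",
    "        # assumption: up-sampling doubles the deeper size; the rest is padded\n",
    "        pads.append(sizes[level - 1] - 2 * sizes[level])\n",
    "    return pads\n",
    "\n",
    "\n",
    "print(encoder_sizes(37))  # [37, 18, 9, 4, 2]\n",
    "print(upcat_padding(37))  # [0, 1, 0, 1]\n",
    "print(upcat_padding(48))  # [0, 0, 0, 0]\n",
    "```\n",
    "\n",
    "An edge of 37 loses an element at two of the four poolings, so two up-sampling steps need one replicated element of padding. An edge of 48 (a multiple of 16) needs none."
   ]
  },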
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Up/Down sampling conclusion\n",
    "\n",
    "It's best to keep your input spatial dimension (`W, D, H`) a multiple of 16 (`16 * x`). They will maximize your data-usage as no padding will be required. The input shape also should be greater than 32 for safety against normalization, as mentioned above."
   ]
  },
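  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to follow this advice is to round every edge up to the next multiple of 16 before padding or resampling the data. The helpers below (`round_up_to_multiple`, `suggested_spatial_shape`) are hypothetical and not part of MONAI. Applying the minimum of 32 to every edge is a conservative assumption, since the normalization rules only need one edge to reach it.\n",
    "\n",
    "```python\n",
    "def round_up_to_multiple(size, multiple=16):\n",
    "    '''Smallest multiple of `multiple` that is >= size.'''\n",
    "    return -(-size // multiple) * multiple\n",
    "\n",
    "\n",
    "def suggested_spatial_shape(spatial_shape, multiple=16, minimum=32):\n",
    "    '''Hypothetical helper: round each edge up to a multiple of 16, at least 32.'''\n",
    "    return tuple(max(minimum, round_up_to_multiple(size, multiple)) for size in spatial_shape)\n",
    "\n",
    "\n",
    "print(suggested_spatial_shape((37, 50, 16)))  # (48, 64, 32)\n",
    "```\n",
    "\n",
    "MONAI's `DivisiblePad` transform with `k=16` should be one option for padding images to a shape of this kind."
   ]
  },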
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Features argument\n",
    "\n",
    "Features argument decide the num. of output channels for each levels of convolution stacks. There are 6 values to fill, anymore or less will result in an error.\n",
    "\n",
    "In the original paper, they use 2 settings:\n",
    "- UNet's : `features=(32, 64, 128, 256, 512, 32)`\n",
    "- Wide UNet's : `features=(32, 70, 140, 280, 560, 32)`\n",
    "\n",
    "Note that, `features[5]` is used only for the last up-sampling + convolution block. Compare to the paper, MONAI's implementation implies that `features[5] = features[0]`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BasicUNetPlusPlus features: (32, 64, 128, 256, 512, 32).\n"
     ]
    }
   ],
   "source": [
    "model = BasicUnetPlusPlus(\n",
    "    spatial_dims=3,\n",
    "    in_channels=3,\n",
    "    out_channels=3,\n",
    "    features=(32, 64, 128, 256, 512, 32),\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
